How to fine-tune Phi-3 Vision on a custom dataset | mlnews3 – Weights & Biases
In this blog post, we’ll fine-tune Phi-3 Vision, a model capable of generating text from image inputs.
The main goal is to create a system that can generate accurate and meaningful textual descriptions based solely on visual inputs. This process includes fine-tuning the model with a specific dataset, optimizing its performance, and ensuring it can handle the task of converting visual information into descriptive text effectively.
The model
Phi-3-Vision-128K-Instruct, a lightweight, state-of-the-art multimodal model, is at the core of this project. Part of the Phi-3 model family, it supports a context length of up to 128K tokens. The model was trained on a diverse dataset that includes synthetic data and carefully filtered publicly available web data, with an emphasis on high-quality, reasoning-intensive content. Its training also included supervised fine-tuning and direct preference optimization to improve instruction following and safety.
Our dataset
The dataset is DBQ/Burberry.Product.prices.United.States, available on Hugging Face. It contains 3,040 rows, each representing a unique Burberry product, with a product image URL and metadata such as the product’s category, price, and title. This dataset lets us test the model’s ability to interpret visual data and generate descriptive text that captures fine visual details and brand-specific characteristics.
Complex reasoning
One interesting aspect of this task is that the model must infer a product’s name and price from the image alone. This requires the model not only to recognize visual features but also to relate them to product value and branding. If the model can produce accurate descriptions from images alone, that shows how visual inputs can extend the usefulness of a language model in real-world applications.
Phi-3 Vision architecture
Phi-3 Vision is a multimodal extension of the Phi-3 language model. It processes both text and image data, combining them into a single sequence for understanding and generation tasks.
The model uses separate embedding layers for text and images. Text tokens are converted into dense vectors, while images are processed through a CLIP vision model to extract feature embeddings. These image embeddings are then projected to match the text embeddings’ dimensions, ensuring they can be seamlessly integrated.
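To make the projection step concrete, here is a toy sketch in plain Python. It assumes a single linear layer with made-up dimensions (8-dimensional vision features, 6-dimensional text embeddings); the real Phi-3 Vision projector and its sizes may differ. The point is only that each image-patch vector is mapped into the same space as the text-token vectors.

```python
import random

def make_linear(in_dim, out_dim, seed=0):
    """Create a toy weight matrix (out_dim x in_dim) and a zero bias vector."""
    rng = random.Random(seed)
    weight = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weight, bias

def project(vectors, weight, bias):
    """Apply y = W x + b to each vector, mapping in_dim -> out_dim."""
    return [
        [sum(w * x_j for w, x_j in zip(row, x)) + b for row, b in zip(weight, bias)]
        for x in vectors
    ]

# Toy sizes (assumptions, not Phi-3's real dimensions).
VISION_DIM, TEXT_DIM, NUM_PATCHES = 8, 6, 4

rng = random.Random(42)
# Stand-in for CLIP patch features: one VISION_DIM vector per image patch.
patch_features = [[rng.random() for _ in range(VISION_DIM)] for _ in range(NUM_PATCHES)]

W, b = make_linear(VISION_DIM, TEXT_DIM)
image_embeddings = project(patch_features, W, b)
print(len(image_embeddings), len(image_embeddings[0]))  # 4 6
```

After projection, every image vector has the same width as a text embedding, which is what allows the two to sit in one sequence.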
Integration of text and image embeddings
Special tokens within the text sequence indicate where the image embeddings should be inserted. During processing, these special tokens are replaced with the corresponding image embeddings, allowing the model to handle text and images as a single sequence.
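The replacement of placeholder tokens can be sketched in plain Python as follows. This is a simplified illustration, not the actual Phi-3 Vision code: it assumes the tokenizer has already expanded the image tag into one placeholder ID per image-embedding vector, and it uses a hypothetical placeholder ID of -1 and a toy lookup table for text embeddings.

```python
IMAGE_PLACEHOLDER_ID = -1  # hypothetical ID marking where image embeddings go

def merge_text_and_image(input_ids, text_embedding_table, image_embeddings,
                         placeholder_id=IMAGE_PLACEHOLDER_ID):
    """Build one embedding sequence, swapping each placeholder for the next image vector."""
    num_placeholders = sum(1 for t in input_ids if t == placeholder_id)
    if num_placeholders != len(image_embeddings):
        raise ValueError(
            f"{num_placeholders} placeholders but {len(image_embeddings)} image embeddings"
        )
    image_iter = iter(image_embeddings)
    merged = []
    for token_id in input_ids:
        if token_id == placeholder_id:
            merged.append(next(image_iter))  # image embedding fills the slot
        else:
            merged.append(text_embedding_table[token_id])  # ordinary text token
    return merged

# Toy example: 2-dim embeddings and a text vocabulary of three tokens.
text_table = {0: [0.0, 0.0], 1: [1.0, 1.0], 2: [2.0, 2.0]}
ids = [1, IMAGE_PLACEHOLDER_ID, IMAGE_PLACEHOLDER_ID, 2]
image_vecs = [[9.0, 9.0], [8.0, 8.0]]
print(merge_text_and_image(ids, text_table, image_vecs))
# [[1.0, 1.0], [9.0, 9.0], [8.0, 8.0], [2.0, 2.0]]
```

The count check up front catches the most common mistake in this step: a prompt whose placeholders do not match the number of image vectors.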
Here is how we format each training example for our dataset, using the special <|image_1|> placeholder token:
```python
text = f"<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\nProduct: {row['title']}, Category: {row['category3_code']}, Full Price: {row['full_price']}<|end|>"
```
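Wrapping this f-string in a small helper makes it easy to apply to every row and shows exactly where the assistant’s answer begins. The sketch below is an illustration rather than part of the original script: `build_example` is a hypothetical helper that returns the prompt and the target answer separately, so that `prompt + answer` reproduces the template above. Keeping them apart is optional; it is one way to identify which tokens belong to the answer.

```python
PROMPT = "<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\n"

def build_example(row):
    """Return (prompt, answer) for one product row; prompt + answer is the full training text."""
    answer = (
        f"Product: {row['title']}, Category: {row['category3_code']}, "
        f"Full Price: {row['full_price']}<|end|>"
    )
    return PROMPT, answer

# Made-up row for illustration only; real rows come from the Burberry CSV.
sample_row = {"title": "Check Cashmere Scarf", "category3_code": "SCARVES", "full_price": 490.0}
prompt, answer = build_example(sample_row)
print(prompt + answer)
```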
Preparing our dataset
The dataset used in this project was sourced from Hugging Face, specifically the DBQ/Burberry.Product.prices.United.States dataset. To facilitate model training, the dataset was first loaded and converted into a Pandas DataFrame for easier manipulation.
Here is a script that downloads the dataset, saves the product images to a local directory, and writes the metadata to a CSV:
```python
import os
import pandas as pd
from datasets import load_dataset
import requests
from PIL import Image
from io import BytesIO

# Function to download an image from a URL and save it locally
def download_image(image_url, save_path):
    try:
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()  # Check if the request was successful
        image = Image.open(BytesIO(response.content))
        # JPEG cannot store alpha or palette modes, so convert to RGB before saving
        image.convert("RGB").save(save_path)
        return True
    except Exception as e:
        print(f"Failed to download {image_url}: {e}")
        return False

# Download the dataset from Hugging Face
dataset = load_dataset('DBQ/Burberry.Product.prices.United.States')

# Convert the Hugging Face dataset to a Pandas DataFrame
df = dataset['train'].to_pandas()

# Create directories to save the dataset and images
dataset_dir = './data/burberry_dataset'
images_dir = os.path.join(dataset_dir, 'images')
os.makedirs(images_dir, exist_ok=True)

# Filter out rows where image download fails
filtered_rows = []
for idx, row in df.iterrows():
    image_url = row['imageurl']
    image_name = f"{row['product_code']}.jpg"
    image_path = os.path.join(images_dir, image_name)
    if download_image(image_url, image_path):
        row['local_image_path'] = image_path
        filtered_rows.append(row)

# Create a new DataFrame with the filtered rows
filtered_df = pd.DataFrame(filtered_rows)

# Save the updated dataset to disk
dataset_path = os.path.join(dataset_dir, 'burberry_dataset.csv')
filtered_df.to_csv(dataset_path, index=False)
print(f"Dataset and images saved to {dataset_dir}")
```
A crucial step in preparing the dataset involves downloading and storing the images locally.
We achieved this with a custom function, download_image, which fetched each image from its URL and saved it using the product code as the filename. This gives every image a consistent, identifiable name, which is what links it to the corresponding product data.
To ensure dataset integrity, we filtered out rows whose image download failed. The DataFrame was updated with the local path of each successfully downloaded image, directly linking product data to image files, and then saved as a CSV, ready for training.
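Before training, it can be worth confirming that every row in the saved CSV still points at an image on disk, for example if files were moved or deleted after the download. The sketch below is an optional sanity check that is not part of the original pipeline; it uses only the standard library, and `load_rows_with_images` is a hypothetical helper that assumes the `local_image_path` column written by the script above.

```python
import csv
import os

def load_rows_with_images(csv_path):
    """Read the saved CSV and keep only rows whose local image file exists."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    kept = [row for row in rows if os.path.isfile(row["local_image_path"])]
    print(f"Kept {len(kept)} of {len(rows)} rows")
    return kept
```

With the paths used in the download script, you would call `load_rows_with_images('./data/burberry_dataset/burberry_dataset.csv')`.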
Training script
Now we’re ready to train our Phi-3 Vision model!
The script begins by initializing the essential components: the dataset, tokenizer, and model. The dataset is split into training and validation sets so that we can evaluate the model’s performance during training. We also save the model that performs best on the validation set locally and upload it to W&B at the end of the training run.
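The split and the "keep the best model" logic do not depend on PyTorch, so here is a plain-Python sketch of both. It illustrates the idea under stated assumptions rather than reproducing the script that follows: the 10% validation fraction is an assumption, the seed of 3 simply mirrors the `torch.manual_seed(3)` call in the script, and `split_indices` and `BestValTracker` are hypothetical helpers. A `True` from `update` marks the point where you would save the model locally before uploading it to W&B at the end of the run.

```python
import random

def split_indices(n_items, val_fraction=0.1, seed=3):
    """Shuffle item indices reproducibly and split them into (train, val)."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_val = int(n_items * val_fraction)
    return indices[n_val:], indices[:n_val]

class BestValTracker:
    """Remember the lowest validation loss seen so far."""

    def __init__(self):
        self.best_loss = float("inf")

    def update(self, val_loss):
        """Return True when val_loss improves on the best so far (i.e. save a checkpoint)."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            return True
        return False

# 3,040 rows before any failed downloads are filtered out
train_idx, val_idx = split_indices(3040)
print(len(train_idx), len(val_idx))  # 2736 304

tracker = BestValTracker()
for epoch, val_loss in enumerate([1.20, 0.95, 1.01, 0.88]):  # made-up losses
    if tracker.update(val_loss):
        print(f"epoch {epoch}: new best val loss {val_loss:.2f}, save checkpoint here")
```

Seeding the shuffle keeps the validation set identical across runs, so validation losses from different runs remain comparable.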
I like tutorials that provide the full training script rather than smaller chunks, since it makes the overall flow of the code easier to follow. Here is the training script I used to train Phi-3 Vision:
```python
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoProcessor
from torchvision import transforms
from PIL import Image
import torch.optim as optim
import pandas as pd
import random
import wandb
import torch.nn.functional as F
import numpy as np
from torchvision.transforms.functional import resize, to_pil_image

torch.manual_seed(3)

# Initialize Weights & Biases
run = wandb.init(project="burberry-product-phi3", entity="byyoung3")

# Custom Dataset for Burberry Product Prices and Images
class BurberryProductDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length, image_size):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = 'left'
        self.max_length = max_length
```
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">def</span></span><span data-slate-leaf="true"><span
data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">__len__</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">self</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">len</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">self</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">dataframe</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">def</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">__getitem__</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">self</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> idx</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span 
data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> row </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> self</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">dataframe</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">iloc</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">idx</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> text </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">f"<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\nProduct: {row['title']}, Category: {row['category3_code']}, Full Price: {row['full_price']}<|end|>"</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> image_path </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> row</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'local_image_path'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span 
data-slate-string="true"># Tokenize text</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> encodings </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> self</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">text</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> truncation</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">True</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> padding</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">'max_length'</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> max_length</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">self</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">max_length</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span 
data-slate-string="true">try</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Load and transform image</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> image </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> Image</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">open</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">image_path</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">convert</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">"RGB"</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> image </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> self</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">image_transform_function</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">image</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">except</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">FileNotFoundError</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> IOError</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Skip the sample if the image is not found</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">None</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> encodings</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'pixel_values'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> 
image</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> encodings</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'price'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> row</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'full_price'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">{</span></span><span data-slate-leaf="true"><span data-slate-string="true">key</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> torch</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">tensor</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">val</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">for</span></span><span data-slate-leaf="true"><span 
data-slate-string="true"> key</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> val </span></span><span data-slate-leaf="true"><span data-slate-string="true">in</span></span><span data-slate-leaf="true"><span data-slate-string="true"> encodings</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">items</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">}</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">def</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">image_transform_function</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">self</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> image</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> image </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> np</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">array</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">image</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> image</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Load dataset from disk</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">dataset_path </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">'./data/burberry_dataset/burberry_dataset.csv'</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">df </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> pd</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">read_csv</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">dataset_path</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Initialize processor and tokenizer</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">model_id </span></span><span 
data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"microsoft/Phi-3-vision-128k-instruct"</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">processor </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> AutoProcessor</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">from_pretrained</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">model_id</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> trust_remote_code</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">True</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">tokenizer </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> processor</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">tokenizer</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Split dataset into training and validation sets</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">train_size 
</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">int</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">0.9</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">*</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">len</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">df</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">val_size </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">len</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">df</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">-</span></span><span data-slate-leaf="true"><span data-slate-string="true"> train_size</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">train_indices</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> val_indices </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> random_split</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">range</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">len</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">df</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">train_size</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> val_size</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">train_indices </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> train_indices</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">indices</span></span></span></p><p><span 
data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">val_indices </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> val_indices</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">indices</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">train_df </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> df</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">iloc</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">train_indices</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">val_df </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> df</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">iloc</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">val_indices</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Create dataset and dataloader</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span 
data-slate-string="true">train_dataset </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> BurberryProductDataset</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">train_df</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> max_length</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">512</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> image_size</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">128</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">val_dataset </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> BurberryProductDataset</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">val_df</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> 
max_length=512, image_size=128)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Initialize model</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto")</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">device = torch.device("cuda" if torch.cuda.is_available() else "cpu")</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">model.to(device)  # Effectively redundant: device_map="cuda" has already placed the model on the GPU</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Optimizer</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">optimizer = optim.AdamW(model.parameters(), lr=5e-5)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Training loop</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">num_epochs = 1</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">eval_interval = 150  # Evaluate every 'eval_interval' steps</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">loss_scaling_factor = 1000.0  # Factor by which the loss is scaled</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">save_dir = './saved_models'</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">step = 0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">accumulation_steps = 64  # Accumulate gradients over this many steps (effective batch size: 1 x 64 = 64)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">if not os.path.exists(save_dir):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    os.makedirs(save_dir)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">best_val_loss = float('inf')</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">best_model_path = None</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Randomly select 10 validation samples whose images and texts are logged to W&amp;B</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">num_log_samples = 10</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">log_indices = random.sample(range(len(val_dataset)), num_log_samples)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">def extract_price_from_predictions(predictions, tokenizer):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    # Assumes the price is the last whitespace-separated token of the decoded text</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    predicted_text = tokenizer.decode(predictions[0], skip_special_tokens=True)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    try:</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">        predicted_price = float(predicted_text.split()[-1].replace(',', ''))</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    except (ValueError, IndexError):  # IndexError: the decoded text was empty</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">        predicted_price = 0.0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    return predicted_price</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">def evaluate(model, val_loader, device, tokenizer, step, log_indices, max_samples=None):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    model.eval()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    total_loss = 0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    total_price_error = 0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_images = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_gt_texts = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_pred_texts = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    table = wandb.Table(columns=["Image", "Ground Truth Text", "Predicted Text"])</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    with torch.no_grad():</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">        for i, batch in enumerate(val_loader):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            if max_samples and i >= max_samples:</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                break</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            if batch is None:  # Skip if the batch is None</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                continue</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            input_ids = batch['input_ids'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            attention_mask = batch['attention_mask'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            pixel_values = batch['pixel_values'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            labels = input_ids.clone().detach()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            actual_price = batch['price'].item()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            outputs = model(</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                input_ids=input_ids,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                attention_mask=attention_mask,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                pixel_values=pixel_values,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                labels=labels</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            )</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            loss = outputs.loss</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            total_loss += loss.item()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # Calculate price error</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # Greedy token per position from this single teacher-forced pass (not autoregressive generation)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            predictions = torch.argmax(outputs.logits, dim=-1)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            predicted_price = extract_price_from_predictions(predictions, tokenizer)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            price_error </span></span><span data-slate-leaf="true"><span
data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">abs</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">predicted_price </span></span><span data-slate-leaf="true"><span data-slate-string="true">-</span></span><span data-slate-leaf="true"><span data-slate-string="true"> actual_price</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> total_price_error </span></span><span data-slate-leaf="true"><span data-slate-string="true">+=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> price_error</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Log images, ground truth texts, and predicted texts</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">if</span></span><span data-slate-leaf="true"><span data-slate-string="true"> i </span></span><span data-slate-leaf="true"><span data-slate-string="true">in</span></span><span data-slate-leaf="true"><span data-slate-string="true"> log_indices</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> log_images</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">append</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">pixel_values</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">cpu</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">squeeze</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">numpy</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> log_gt_texts</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">append</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">decode</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">labels</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">0</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> skip_special_tokens</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">True</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> log_pred_texts</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">append</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">decode</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">predictions</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">0</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> skip_special_tokens</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">True</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Convert image to PIL format</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> pil_img </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> to_pil_image</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">resize</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">torch</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">from_numpy</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">log_images</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">-</span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">permute</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">2</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span 
data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">0</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">336</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">336</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">convert</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">"RGB"</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Add data to the table</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> table</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">add_data</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">wandb</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">Image</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">pil_img</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> log_gt_texts</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">-</span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> log_pred_texts</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">-</span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Log the table incrementally</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> wandb</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">log</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">{</span></span><span data-slate-leaf="true"><span data-slate-string="true">"Evaluation Results step {}"</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">format</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">step</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> table</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"Step"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> step</span></span><span data-slate-leaf="true"><span data-slate-string="true">}</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> avg_loss </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> total_loss </span></span><span data-slate-leaf="true"><span data-slate-string="true">/</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">i </span></span><span data-slate-leaf="true"><span data-slate-string="true">+</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># i+1 to account for the loop index</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> avg_price_error </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> total_price_error </span></span><span data-slate-leaf="true"><span data-slate-string="true">/</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">i </span></span><span data-slate-leaf="true"><span data-slate-string="true">+</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> model</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">train</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span 
data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> avg_loss</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> avg_price_error</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">model</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">train</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">for</span></span><span data-slate-leaf="true"><span data-slate-string="true"> epoch </span></span><span data-slate-leaf="true"><span data-slate-string="true">in</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">range</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">num_epochs</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Number of epochs</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> total_train_loss </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span 
data-slate-string="true">0</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> total_train_price_error </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">0</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> batch_count </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">0</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">for</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch </span></span><span data-slate-leaf="true"><span data-slate-string="true">in</span></span><span data-slate-leaf="true"><span data-slate-string="true"> train_loader</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> step </span></span><span data-slate-leaf="true"><span data-slate-string="true">+=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">if</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch </span></span><span data-slate-leaf="true"><span data-slate-string="true">is</span></span><span 
data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">None</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Skip if the batch is None</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">continue</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> input_ids </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'input_ids'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">to</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">device</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> attention_mask </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'attention_mask'</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">to</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">device</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> pixel_values </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'pixel_values'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">to</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">device</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> labels </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> input_ids</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">clone</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">detach</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> actual_price </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> batch</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'price'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">float</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">to</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">device</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> outputs </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> model</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> input_ids</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">input_ids</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> attention_mask</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">attention_mask</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> pixel_values</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">pixel_values</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> labels</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">labels</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> loss </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> outputs</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span 
```python
        # ... inside the per-batch training loop, after the forward pass that produces `outputs` and `loss`
        total_loss = loss
        # Greedy next-token predictions under teacher forcing, decoded into a price estimate
        predictions = torch.argmax(outputs.logits, dim=-1)
        predicted_price = extract_price_from_predictions(predictions, tokenizer)
        total_loss.backward()  # adds this batch's gradients into param.grad

        # Update the weights once every `accumulation_steps` batches.
        # If `step` starts at 0 (as with enumerate), the first update fires after a single batch;
        # `(step + 1) % accumulation_steps == 0` would make every accumulation window full.
        if (step % accumulation_steps) == 0:
            for param in model.parameters():
                if param.grad is not None:
                    param.grad /= accumulation_steps  # turn the summed gradients into an average
            optimizer.step()
            optimizer.zero_grad()

        total_train_loss += total_loss.item()
        total_train_price_error += abs(predicted_price - actual_price.item())
        batch_count += 1

        # Log batch loss to wandb
        wandb.log({"Batch Loss": total_loss.item(), "Step": step})
        print(f"Epoch: {epoch}, Step: {step}, Batch Loss: {total_loss.item()}")

        if step % eval_interval == 0:
            val_loss, val_price_error = evaluate(
                model, val_loader, device, tokenizer=tokenizer, log_indices=log_indices, step=step
            )
            wandb.log({
                "Validation Loss": val_loss,
                "Validation Price Error (Average)": val_price_error,
                "Step": step,
            })
            print(f"Step: {step}, Validation Loss: {val_loss}, Validation Price Error (Average): {val_price_error}")

            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = os.path.join(save_dir, "best_model")
                model.save_pretrained(best_model_path, safe_serialization=False)
                tokenizer.save_pretrained(best_model_path)
            model.train()  # evaluate() may leave the model in eval mode; resume training mode

    avg_train_loss = total_train_loss / batch_count
    avg_train_price_error = total_train_price_error / batch_count
    wandb.log({
        "Epoch": epoch,
        "Average Training Loss": avg_train_loss,
        "Average Training Price Error": avg_train_price_error,
    })
    print(f"Epoch: {epoch}, Average Training Loss: {avg_train_loss}, Average Training Price Error: {avg_train_price_error}")

# Assumes best_model_path was initialized (e.g. to None) before training started
if best_model_path:
    run.log_model(
        path=best_model_path,
        name="phi3-v-burberry",
        aliases=["best"],
    )
wandb.finish()
```
Gradient accumulation
To make training efficient, the script relies on a few key techniques. The first is gradient accumulation, which lets us train with large effective batches despite the heavy memory demands of multimodal models. Instead of updating the weights after every batch, the model accumulates gradients over several steps and then performs a single weight update. This simulates a larger batch size and tends to stabilize training.
Here’s how we do gradient accumulation:
```python
if (step % accumulation_steps) == 0:
    for param in model.parameters():
        if param.grad is not None:
            # Average the gradients summed over the last `accumulation_steps` backward passes
            param.grad /= accumulation_steps
    optimizer.step()
    optimizer.zero_grad()
```
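To see why dividing by `accumulation_steps` works, here is a small pure-Python sketch (no PyTorch) for a one-parameter linear model with a mean-squared-error loss. The helper names are hypothetical. It shows that summing the gradients of equally sized micro-batches and then dividing by the number of micro-batches gives the same gradient as one pass over the combined batch, which is what `param.grad /= accumulation_steps` relies on, since PyTorch adds the result of each `backward()` call into `param.grad`.

```python
# Pure-Python sketch (hypothetical helpers, no PyTorch): gradient accumulation
# over equally sized micro-batches reproduces the full-batch gradient.

def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w for one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def full_batch_grad(w, xs, ys):
    """Gradient from a single pass over the whole (large) batch."""
    return mse_grad(w, xs, ys)


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Sum micro-batch gradients (like repeated backward() calls), then divide
    by the number of micro-batches (like `param.grad /= accumulation_steps`)."""
    grad_sum = 0.0
    accumulation_steps = 0
    for start in range(0, len(xs), micro_batch_size):
        end = start + micro_batch_size
        grad_sum += mse_grad(w, xs[start:end], ys[start:end])
        accumulation_steps += 1
    return grad_sum / accumulation_steps


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [3.0, 5.0, 7.0, 9.0]  # y = 2x + 1
    print(full_batch_grad(0.5, xs, ys), accumulated_grad(0.5, xs, ys, 2))  # -27.5 -27.5
```

The equality is exact only when every micro-batch has the same size and the loss is a per-batch mean; a smaller final batch in an epoch gets slightly more weight per sample than it would in a true large batch.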
Evaluation
During training, we evaluate the model at regular intervals and save the best-performing checkpoint, defined as the one with the lowest validation loss. This way, the final model keeps the most effective parameters learned during training.
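The checkpointing rule amounts to keeping a running minimum of the validation loss and saving only when it improves. Below is a minimal sketch, assuming `best_val_loss` starts at infinity (its initialization is not shown in this part of the script). `BestCheckpointTracker` is a hypothetical name, and the loss values are illustrative; the comment marks where the script's `save_pretrained` calls would go.

```python
import math


class BestCheckpointTracker:
    """Hypothetical helper: remembers the lowest validation loss seen so far."""

    def __init__(self):
        self.best_val_loss = math.inf  # assumed starting value
        self.best_step = None

    def update(self, val_loss, step):
        """Return True when `val_loss` beats the best so far (time to save)."""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_step = step
            return True
        return False


# Example: validation losses observed at four evaluation steps (illustrative values)
tracker = BestCheckpointTracker()
saved_steps = []
for step, val_loss in [(0, 2.1), (50, 1.7), (100, 1.9), (150, 1.4)]:
    if tracker.update(val_loss, step):
        saved_steps.append(step)  # the real script calls save_pretrained here

print(saved_steps, tracker.best_val_loss)  # [0, 50, 150] 1.4
```

Because only improvements trigger a save, the checkpoint on disk always corresponds to the lowest validation loss so far, at the cost of one full model write each time the loss improves.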
We pay special attention to logging price prediction errors, which show how accurately the model predicts prices. Tracking this metric tells us how well the model is learning to estimate prices from images alone, a difficult task that requires fine-grained visual understanding.
The training loop iterates through the dataset, processing batches of text and image data. For each batch, the model's predictions are compared with the ground-truth text and the loss is computed. The loss is then backpropagated and the gradients are accumulated; after a set number of steps, the accumulated gradients are used to update the model's weights.
Evaluation is a critical part of the training script. At specified intervals, the model is evaluated on the validation set, which shows how well it generalizes to unseen data. The evaluation function computes the average validation loss and price prediction error and logs both metrics to W&B.
The following code runs the evaluation and logs its metrics to W&B:
```python
if step % eval_interval == 0:
    val_loss, val_price_error = evaluate(
        model, val_loader, device, tokenizer=tokenizer, log_indices=log_indices, step=step
    )
    wandb.log({
        "Validation Loss": val_loss,
        "Validation Price Error (Average)": val_price_error,
        "Step": step,
    })
```
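The body of `evaluate` is only partially shown, so here is one hypothetical way it could turn per-batch results into the averages logged above. It weights each batch's mean loss by its number of samples, so a smaller final batch does not skew the result. Both this weighting and the tuple format of the input are assumptions, not necessarily what the original script does; a token-weighted average would be another reasonable choice.

```python
def average_validation_metrics(batch_results):
    """Combine per-batch results into dataset-level averages.

    Each item in `batch_results` is a hypothetical tuple
    (mean_batch_loss, batch_size, abs_price_errors), where abs_price_errors
    holds one |predicted - actual| value per sample in the batch.
    """
    weighted_loss_sum = 0.0
    total_samples = 0
    all_errors = []
    for mean_loss, batch_size, abs_errors in batch_results:
        weighted_loss_sum += mean_loss * batch_size  # weight each batch by its size
        total_samples += batch_size
        all_errors.extend(abs_errors)
    avg_loss = weighted_loss_sum / total_samples
    avg_price_error = sum(all_errors) / len(all_errors) if all_errors else float("nan")
    return avg_loss, avg_price_error


# Illustrative input: a full batch of 4 and a smaller final batch of 2
avg_loss, avg_price_error = average_validation_metrics([
    (2.0, 4, [10.0, 20.0, 30.0, 40.0]),
    (1.0, 2, [5.0, 15.0]),
])
print(round(avg_loss, 3), avg_price_error)  # 1.667 20.0
```

A plain mean of the two batch losses would give 1.5 here instead of about 1.67, which is the skew the sample weighting avoids.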
We track “price error” to evaluate how accurately the model predicts product prices. It is the absolute difference between the predicted and actual price, averaged over the validation batches, so it measures price prediction directly instead of relying only on the traditional ‘next token’ loss, which says little about how far off a predicted price actually is. This gives a more direct check on whether the model is reliable enough for real-world pricing applications.
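The evaluation code calls an `extract_price_from_predictions` helper that isn’t shown in this section. As a rough sketch of the idea, here is a hypothetical pure-Python version that works on decoded text rather than token IDs. It assumes the price appears after a dollar sign (“$1,290.00”) or after the word “price” (“Price: 450”); `parse_price`, `mean_absolute_price_error`, and the optional `missing_penalty` are illustrative names, not part of the original code.

```python
import re
from typing import Optional, Sequence

# Assumption: the text states the price after a "$" ("$1,290.00")
# or after the word "price" ("Price: 450"). Other formats are not recognized.
PRICE_PATTERN = re.compile(
    r"(?:\$|\bprice\b\s*:?\s*\$?)\s*(\d[\d,]*(?:\.\d+)?)", re.IGNORECASE
)


def parse_price(text: str) -> Optional[float]:
    """Return the first price found in `text`, or None if no price is present."""
    match = PRICE_PATTERN.search(text)
    if match is None:
        return None
    # Drop thousands separators before converting, e.g. "1,290.00" -> 1290.0
    return float(match.group(1).replace(",", ""))


def mean_absolute_price_error(
    predicted_texts: Sequence[str],
    actual_prices: Sequence[float],
    missing_penalty: Optional[float] = None,
) -> float:
    """Average |predicted - actual| over samples.

    Predictions without a parseable price are skipped, unless `missing_penalty`
    is given, in which case that value is charged as their error instead.
    Returns NaN when no sample contributes an error.
    """
    errors = []
    for text, actual in zip(predicted_texts, actual_prices):
        predicted = parse_price(text)
        if predicted is None:
            if missing_penalty is not None:
                errors.append(missing_penalty)
            continue
        errors.append(abs(predicted - actual))
    return sum(errors) / len(errors) if errors else float("nan")
```

Skipping unparseable outputs keeps the average focused on predictions that do contain a price, but it can flatter a model that often omits the price; passing a `missing_penalty` charges those cases instead.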
We select a fixed group of ten samples from the validation set and track them across multiple validation checkpoints during training. For each of these samples we log the image, the ground-truth text, and the predicted text to a table, giving a side-by-side view of how the model’s outputs change over time. Together with the quantitative metrics, these logs give a thorough picture of the model’s ability to generate accurate and meaningful outputs.
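One way to pick those ten samples is to draw the indices once, with a fixed seed, before training starts and pass the same list to every evaluation call. The sketch below is an assumption about how `log_indices` could be built, not the post’s exact code; `choose_log_indices` is a hypothetical helper. Because `evaluate` compares `log_indices` against the batch index `i`, it also assumes a batch size of 1 and a validation `DataLoader` created with `shuffle=False`, so that index `i` refers to the same product at every checkpoint.

```python
import random
from typing import List


def choose_log_indices(num_val_batches: int, k: int = 10, seed: int = 0) -> List[int]:
    """Pick k distinct validation batch indices, reproducibly.

    Reusing the same indices at every evaluation lets the logged table show
    how predictions for the same products change over training.
    """
    if k > num_val_batches:
        raise ValueError("k cannot exceed the number of validation batches")
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return sorted(rng.sample(range(num_val_batches), k))
```

For example, `log_indices = choose_log_indices(len(val_loader))`; if you cap evaluation with `max_samples`, pass `min(len(val_loader), max_samples)` instead so every chosen index is actually reached.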
At each evaluation, we build a table containing the image, the ground-truth label, and the model’s predicted text at that point in the training run, and log it to W&B with the following code:
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">import torch</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">import wandb</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">from itertools import islice</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">from torchvision.transforms.functional import resize, to_pil_image</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">def evaluate(model, val_loader, device, tokenizer, step, log_indices, max_samples=None):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    model.eval()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    total_loss = 0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    total_price_error = 0</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_images = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_gt_texts = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    log_pred_texts = []</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    table = wandb.Table(columns=["Image", "Ground Truth Text", "Predicted Text"])  # init table</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    with torch.no_grad():</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">        # islice stops after max_samples batches (or reads them all if it is None)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">        for i, batch in enumerate(islice(val_loader, max_samples or None)):</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            if batch is None:  # Skip if the batch is None</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                continue</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            input_ids = batch['input_ids'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            attention_mask = batch['attention_mask'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            pixel_values = batch['pixel_values'].to(device)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            labels = input_ids.clone().detach()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            actual_price = batch['price'].item()  # .item() requires a batch size of 1</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            outputs = model(</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                input_ids=input_ids,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                attention_mask=attention_mask,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                pixel_values=pixel_values,</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                labels=labels</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            )</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            loss = outputs.loss</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            total_loss += loss.item()</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # Calculate price error from greedy, teacher-forced predictions</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # (the logits at position t score the token at position t + 1)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            predictions = torch.argmax(outputs.logits, dim=-1)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # extract_price_from_predictions is not defined in this snippet; it must come from elsewhere in the script</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            predicted_price = extract_price_from_predictions(predictions, tokenizer)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            price_error = abs(predicted_price - actual_price)</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            total_price_error += price_error</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            # Log images, ground truth texts, and predicted texts</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">            if i in log_indices:</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                log_images.append(pixel_values.cpu().squeeze().numpy())</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                log_gt_texts.append(tokenizer.decode(labels[0], skip_special_tokens=True))</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                log_pred_texts.append(tokenizer.decode(predictions[0], skip_special_tokens=True))</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                # Convert image to PIL format</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                pil_img = to_pil_image(resize(torch.from_numpy(log_images[-1]).permute(2, 0, 1), (336, 336))).convert("RGB")</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                # Add data to the table</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">                table.add_data(wandb.Image(pil_img), log_gt_texts[-1], log_pred_texts[-1])</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    # Log this evaluation's table under a step-specific key</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    wandb.log({f"Evaluation Results step {step}": table, "Step": step})</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    avg_loss = total_loss / (i + 1)  # i is zero-based, so i + 1 batches were read</span></span></span></p>
<p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">    avg_price_error 
</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> total_price_error </span></span><span data-slate-leaf="true"><span data-slate-string="true">/</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">i </span></span><span data-slate-leaf="true"><span data-slate-string="true">+</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> model</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">train</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> avg_loss</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> avg_price_error</span></span></span></p>
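The evaluation loop above accumulates total_price_error for each sample. The post's exact parsing logic may differ, but one way to compute such an error is to pull the first dollar amount out of the generated text and the ground-truth text and take the absolute difference. The following is a minimal sketch under stated assumptions. The helper names extract_price and price_error are hypothetical. It assumes prices appear in a "$1,290.00"-style format, and it returns None when either text has no price. The example strings are made up.

```python
import re
from typing import Optional

# Assumption: prices are written like "$1,290" or "$1,290.00" in both texts.
_PRICE_RE = re.compile(r"\$\s*([0-9][0-9,]*(?:\.[0-9]+)?)")


def extract_price(text: str) -> Optional[float]:
    """Return the first dollar amount found in text, or None if there is none."""
    match = _PRICE_RE.search(text)
    if match is None:
        return None
    # Drop thousands separators before converting to a float
    return float(match.group(1).replace(",", ""))


def price_error(predicted_text: str, ground_truth_text: str) -> Optional[float]:
    """Absolute difference between predicted and true prices, or None if either is missing."""
    pred = extract_price(predicted_text)
    true = extract_price(ground_truth_text)
    if pred is None or true is None:
        return None
    return abs(pred - true)


if __name__ == "__main__":
    # Made-up example strings, not rows from the dataset
    gt = "Category: Bags, Price: $2,290.00, Title: Example Tote"
    pred = "Category: Bags, Price: $1,990.00, Title: Example Tote"
    print(price_error(pred, gt))  # 300.0
```

How to treat a missing price (skip the sample, or add a fixed penalty) is a design choice. It affects how the averaged error should be read.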
Model logging
The best model, determined by the lowest validation loss, is saved locally. After training completes, this checkpoint is logged to W&B with run.log_model(). Logging it this way gives us one central, versioned location for all of our models.
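Keeping only the checkpoint with the lowest validation loss comes down to tracking a running minimum. The sketch below is a hypothetical BestCheckpointTracker class, not code from the post. Its update() method returns True whenever a new best validation loss arrives, which is the point where a training script would save the model locally. The loss values in the example are made up.

```python
import math
from dataclasses import dataclass


@dataclass
class BestCheckpointTracker:
    """Remembers the lowest validation loss seen so far and the step it occurred at."""

    best_loss: float = math.inf
    best_step: int = -1

    def update(self, val_loss: float, step: int) -> bool:
        """Return True if val_loss is a new best, meaning the caller should save a checkpoint."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_step = step
            return True
        return False


if __name__ == "__main__":
    tracker = BestCheckpointTracker()
    # Made-up (step, validation loss) pairs
    for step, loss in [(100, 1.8), (200, 1.2), (300, 1.4)]:
        if tracker.update(loss, step):
            # A real script would save here, e.g. with model.save_pretrained(save_dir)
            print(f"step {step}: new best val loss {loss}")
    print(tracker.best_step, tracker.best_loss)  # 200 1.2
```

Because the local directory is overwritten only on improvement, the directory passed to run.log_model() at the end of training holds just the best checkpoint.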
By the end of training, the model is substantially better at generating product metadata from images. It predicts product prices along with other fields such as title and category.
Here are the logs for my training run:
Using the W&B Model Registry
Once the model is saved to W&B, we can add it to the W&B Model Registry. The Model Registry is a central repository for managing and versioning machine learning models. It tracks model lineage, lets us compare versions, and makes it straightforward to promote the best-performing model for deployment.
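Registered models are addressed by reference strings, and each new link gets a version such as v0 or v1; versions can also carry aliases such as latest. The inference script later in this post loads 'byyoung3/model-registry/phi3-v-burberry:v0'. The parse_artifact_ref helper below is hypothetical, not a W&B API. It splits such a reference into entity, project, name, and version. It accepts only the full three-part form and assumes a missing version means latest; both are simplifications made for this sketch.

```python
from typing import NamedTuple


class ArtifactRef(NamedTuple):
    entity: str
    project: str
    name: str
    version: str  # a version such as "v0" or an alias such as "latest"


def parse_artifact_ref(ref: str) -> ArtifactRef:
    """Split 'entity/project/name[:version]' into its parts."""
    path, sep, version = ref.rpartition(":")
    if not sep:
        # No ":" present: the whole string is the path; assume the "latest" alias
        path, version = ref, "latest"
    parts = path.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected 'entity/project/name[:version]', got {ref!r}")
    entity, project, name = parts
    return ArtifactRef(entity, project, name, version)


if __name__ == "__main__":
    print(parse_artifact_ref("byyoung3/model-registry/phi3-v-burberry:v0"))
    # ArtifactRef(entity='byyoung3', project='model-registry', name='phi3-v-burberry', version='v0')
```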
First, navigate to the Artifacts pane on the run page. You should see a row that looks like this:
Click this row, and you will be redirected to another page that looks like this:
At the top right is a “Link to registry” button, which adds the model to the Model Registry. Clicking it lets you either link the model to an existing registered model or create a new one. Since we haven't registered this model yet, choose the option to register a new model.
After registering the model, open the Model Registry page, where you will see your model:
Running inference with Phi-3 Vision
With the model added to the Model Registry, we can access it programmatically and run inference.
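The inference script builds its prompt by hand from Phi-3 Vision's special tokens. <|user|> opens the user turn, <|image_N|> marks where the Nth image goes, <|end|> closes the user turn, and <|assistant|> starts the model's reply. The build_prompt helper below is hypothetical and not part of the post. It assumes exactly the same layout as the script's hard-coded prompt and simply makes that structure explicit.

```python
def build_prompt(question: str, num_images: int = 1) -> str:
    """Build a prompt with numbered image placeholders, matching the script's layout."""
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    # One placeholder per image: <|image_1|>, <|image_2|>, ...
    image_tags = "".join(f"<|image_{i}|>" for i in range(1, num_images + 1))
    return f"<|user|>\n{image_tags}{question}<|end|><|assistant|>\n"


if __name__ == "__main__":
    print(repr(build_prompt("What is shown in this image?")))
```

The chat-format example on the Phi-3 Vision model card also puts a newline after <|image_1|> and after <|end|>. Reusing whichever layout was used during fine-tuning is likely the safest choice at inference time.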
Here’s the script that will allow us to accomplish this:
import weave
import os
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import requests
from io import BytesIO
import base64
from pathlib import Path
import wandb

# Initialize Weights & Biases run and download the registered model
run = wandb.init(project='burberry-product-price-prediction')
artifact = run.use_artifact('byyoung3/model-registry/phi3-v-burberry:v0', type='model')
artifact_dir = artifact.download()
print(f"Artifact downloaded to: {artifact_dir}")

# Base model ID, used to load the processor
model_id = "microsoft/Phi-3-vision-128k-instruct"
try:
    # Load the fine-tuned weights from the downloaded artifact
    # (flash_attention_2 requires the flash-attn package and a supported GPU)
    model = AutoModelForCausalLM.from_pretrained(
        artifact_dir,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
except Exception as e:
    print(f"Error loading model or processor: {e}")
    raise

# Ensure the model is on the correct device
device = 'cuda'
model.to(device)

# Map file extensions to MIME types for data URLs
# (Pillow cannot write SVG, so only raster formats are listed)
EXT_TO_MIMETYPE = {
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
}

# Function to convert image to data URL
def image_to_data_url(image: Image.Image, ext: str) -> str:
    ext = ext.lower()
    if ext not in EXT_TO_MIMETYPE:
        ext = '.jpg'  # Default to .jpg if extension is not recognized
    mimetype = EXT_TO_MIMETYPE[ext]
    image_format = 'JPEG' if ext in ['.jpg', '.jpeg'] else ext.replace('.', '').upper()
    if image_format == 'JPEG' and image.mode != 'RGB':
        image = image.convert('RGB')  # Pillow can't save RGBA or palette images as JPEG
    buffered = BytesIO()
    image.save(buffered, format=image_format)
    encoded_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return f"data:{mimetype};base64,{encoded_string}"

# Function to run inference on a single image
@weave.op()
def run_inference(image_url: str) -> dict:
    try:
        prompt = "<|user|>\n<|image_1|>What is shown in this image?<|end|><|assistant|>\n"
        # Load image, failing fast on network errors instead of hanging
        response = requests.get(image_url, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw)
        ext = Path(image_url).suffix
        # Convert image to data URL
        data_url = image_to_data_url(image, ext)
        inputs = processor(prompt, [image], return_tensors="pt").to(device)
        generation_args = 
data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">{</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"max_new_tokens"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">500</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"temperature"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">0.0</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"do_sample"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">False</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span 
data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">}</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> generate_ids </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> model</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">generate</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">**</span></span><span data-slate-leaf="true"><span data-slate-string="true">inputs</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> eos_token_id</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">processor</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">tokenizer</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">eos_token_id</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">**</span></span><span data-slate-leaf="true"><span data-slate-string="true">generation_args</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true"> 
</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true"># Remove input tokens </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> generate_ids </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> generate_ids</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> inputs</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'input_ids'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">shape</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">1</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> response_text </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> processor</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">batch_decode</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">generate_ids</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> skip_special_tokens</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">True</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> clean_up_tokenization_spaces</span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true">False</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">0</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">return</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">{</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"predicted_text"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> response_text</span></span><span data-slate-leaf="true"><span 
data-slate-string="true">,</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"image_data_url"</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span><span data-slate-leaf="true"><span data-slate-string="true"> data_url</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">}</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">except</span></span><span data-slate-leaf="true"><span data-slate-string="true"> Exception </span></span><span data-slate-leaf="true"><span data-slate-string="true">as</span></span><span data-slate-leaf="true"><span data-slate-string="true"> e</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">print</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">f"Error during inference: {e}"</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">raise</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Initialize Weave project</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span 
data-slate-string="true">weave</span></span><span data-slate-leaf="true"><span data-slate-string="true">.</span></span><span data-slate-leaf="true"><span data-slate-string="true">init</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">'burberry-product-price-prediction'</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"># Example usage</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">image_url </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">"https://assets.burberry.com/is/image/Burberryltd/1C09D316-7A71-472C-8877-91CEFBDB268A?$BBY_V3_SL_1$&wid=1501&hei=1500"</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">try</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> result </span></span><span data-slate-leaf="true"><span data-slate-string="true">=</span></span><span data-slate-leaf="true"><span data-slate-string="true"> run_inference</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">image_url</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">print</span></span><span 
data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">"Predicted Text:"</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> result</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'predicted_text'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">print</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">"Image Data URL:"</span></span><span data-slate-leaf="true"><span data-slate-string="true">,</span></span><span data-slate-leaf="true"><span data-slate-string="true"> result</span></span><span data-slate-leaf="true"><span data-slate-string="true">[</span></span><span data-slate-leaf="true"><span data-slate-string="true">'image_data_url'</span></span><span data-slate-leaf="true"><span data-slate-string="true">]</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span data-slate-string="true">except</span></span><span data-slate-leaf="true"><span data-slate-string="true"> Exception </span></span><span data-slate-leaf="true"><span data-slate-string="true">as</span></span><span data-slate-leaf="true"><span data-slate-string="true"> e</span></span><span data-slate-leaf="true"><span data-slate-string="true">:</span></span></span></p><p><span data-slate-node="text"><span data-slate-leaf="true"><span 
data-slate-string="true"> </span></span><span data-slate-leaf="true"><span data-slate-string="true">print</span></span><span data-slate-leaf="true"><span data-slate-string="true">(</span></span><span data-slate-leaf="true"><span data-slate-string="true">f"Error running inference: {e}"</span></span><span data-slate-leaf="true"><span data-slate-string="true">)</span></span></span></p>
The script above initializes a Weights & Biases run, downloads the fine-tuned model artifact, and defines a function, run_inference, which takes an image URL, fetches the image, converts it to a data URL, and runs inference with the loaded model. The function builds a Phi-3 Vision chat prompt, turns the prompt and image into input tensors with the processor, generates up to 500 new tokens with greedy decoding, and decodes only the newly generated tokens into text. The example at the end runs inference on a Burberry product image and prints the model’s predicted text along with the image’s data URL.
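Because `model.generate` on a decoder-only model returns the prompt tokens followed by the newly generated ones, the script slices off the first `inputs['input_ids'].shape[1]` positions before decoding. The minimal sketch below reproduces that slicing with plain Python lists standing in for the 2-D tensor; the helper name `strip_prompt_tokens` and the token ids are hypothetical, chosen only for illustration.

```python
from typing import List


def strip_prompt_tokens(generated: List[List[int]], prompt_len: int) -> List[List[int]]:
    """Keep only the tokens produced after the prompt, for every sequence in the batch.

    Mirrors `generate_ids[:, inputs['input_ids'].shape[1]:]` on a 2-D tensor:
    each row keeps the columns from index `prompt_len` onward.
    """
    return [row[prompt_len:] for row in generated]


# Hypothetical ids: a 4-token prompt followed by 3 generated tokens.
prompt_ids = [[101, 7, 8, 9]]
generated_ids = [[101, 7, 8, 9, 42, 43, 2]]

new_tokens = strip_prompt_tokens(generated_ids, len(prompt_ids[0]))
print(new_tokens)  # [[42, 43, 2]]
```

Without this step, `batch_decode` would return the prompt text ("What is shown in this image?") in front of every prediction.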
Additionally, we have integrated W&B Weave into our inference script, which logs the outputs of our model. Weave is a lightweight toolkit from Weights & Biases for tracking and evaluating large language model (LLM) applications. Decorating a Python function with @weave.op() makes Weave log that function’s inputs and outputs, which helps with debugging, building evaluations, and organizing information from experimentation to production. In our example, the model is loaded from a W&B artifact, weave.init starts the Weave project, and the decorated run_inference function is traced on every call. Because run_inference also converts each image to a string (a base64 data URL) and returns it with the prediction, Weave logs the images as well.
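Images can be logged by Weave here because run_inference returns each image as a data URL string next to the predicted text. The script’s own image_to_data_url helper is defined earlier and works on a PIL image; the sketch below shows the same encoding using only the standard library, assuming the image is already available as encoded bytes. The name `bytes_to_data_url` and the `application/octet-stream` fallback are our assumptions, not the article’s code.

```python
import base64
import mimetypes


def bytes_to_data_url(data: bytes, ext: str) -> str:
    """Encode raw image bytes as a data URL such as "data:image/png;base64,...".

    `ext` is a file suffix like ".png" or "png"; an empty or unknown suffix
    falls back to the generic binary MIME type (an assumption of this sketch).
    """
    suffix = "." + ext.lstrip(".") if ext else ""
    mimetype = mimetypes.guess_type("image" + suffix)[0] or "application/octet-stream"
    encoded_string = base64.b64encode(data).decode("utf-8")
    return f"data:{mimetype};base64,{encoded_string}"


# Usage: the first four bytes of the PNG signature, purely for illustration.
url = bytes_to_data_url(b"\x89PNG", ".png")
print(url)  # data:image/png;base64,iVBORw==
```

Returning a plain string keeps the function output serializable, so the tracer can store it alongside the text prediction without any image-specific handling.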
Here’s what it looks like in Weave after running our inference script:
Slack integration
Now, I want to show off a really cool feature of W&B. Hypothetically, let’s say another team is responsible for model evaluation, and we want to notify that team every time we upload a new model to our W&B model registry so it can begin evaluating the model. Let’s also assume the team uses Slack. W&B Model Registry integrates with Slack, so we can automate the process of letting our evaluation team know that a new model is ready to be evaluated. Click the registered model on the model registry page, and you will be presented with the following:
Clicking the “Connect Slack” button lets you connect the registry to a Slack channel of your choosing. From then on, you will get a Slack notification whenever a new model is added to your registry:
Conclusion
This project demonstrates the capability of the Phi-3-Vision-128K-Instruct model to process image data and synthesize text from it. After fine-tuning, the model generates accurate and meaningful textual descriptions from visual inputs alone.
Working with a dataset of Burberry products that includes each product’s category, image, price, and title, the model has shown that it can understand and interpret visual data well enough to make inferences about prices and product names from images alone. This task is particularly intriguing because the model must not only recognize and process visual features but also understand what they imply about product value and branding.
The project also shows how Weights & Biases fits into this workflow: W&B manages the model artifacts, and Weave tracks model predictions during inference. Saving the best model to the W&B model registry provides easy access and version control, and the Slack integration sends an automated notification whenever a new model is uploaded to the registry, so updates reach the evaluation team promptly. Overall, I hope you enjoyed this tutorial, and if you are interested in the source code, you can find it here!